package org.test;

import scala.collection.Iterator;
import scala.collection.mutable.WrappedArray;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

/**
 * Numeric calculation utilities, intended to be used together with Spark UDFs.
 *
 * <p>Most statistics are offered for both Scala {@link WrappedArray} input and Java {@link List}
 * input. The {@code WrappedArray} overloads copy their elements into a {@code List} and delegate
 * to the {@code List} overload, so both variants always produce the same result.
 *
 * <p>Elements are expected to be non-null; a {@code null} element will generally cause a
 * {@link NullPointerException} when it is unboxed.
 */
public class NumCalcUtil {

    /** Lower breakpoint between the lower-tail and central regions of {@link #getNormsinv}. */
    private static final double NORMSINV_P_LOW = 0.02425;

    /** Upper breakpoint between the central and upper-tail regions of {@link #getNormsinv}. */
    private static final double NORMSINV_P_HIGH = 0.97575;

    /** Acklam rational-approximation numerator coefficients for the central region. */
    private static final double[] NORMSINV_A = {-3.969683028665376e+01, 2.209460984245205e+02,
            -2.759285104469687e+02, 1.383577518672690e+02,
            -3.066479806614716e+01, 2.506628277459239e+00};

    /** Acklam rational-approximation denominator coefficients for the central region. */
    private static final double[] NORMSINV_B = {-5.447609879822406e+01, 1.615858368580409e+02,
            -1.556989798598866e+02, 6.680131188771972e+01,
            -1.328068155288572e+01};

    /** Acklam rational-approximation numerator coefficients for the tail regions. */
    private static final double[] NORMSINV_C = {-7.784894002430293e-03, -3.223964580411365e-01,
            -2.400758277161838e+00, -2.549732539343734e+00,
            4.374664141464968e+00, 2.938163982698783e+00};

    /** Acklam rational-approximation denominator coefficients for the tail regions. */
    private static final double[] NORMSINV_D = {7.784695709041462e-03, 3.224671290700398e-01,
            2.445134137142996e+00, 3.754408661907416e+00};

    /**
     * Returns the arithmetic mean.
     *
     * @param array input values
     * @return the mean, or {@code NaN} when {@code array} is empty
     */
    public static double getMean(WrappedArray<Double> array) {
        return getMean(toList(array));
    }

    /**
     * Returns the arithmetic mean.
     *
     * @param array input values
     * @return the mean, or {@code NaN} when {@code array} is empty (0 / 0)
     */
    public static double getMean(List<Double> array) {
        double sum = 0;
        for (double value : array) {
            sum += value;
        }
        return sum / array.size();
    }

    /**
     * Returns the sample variance (divisor {@code n - 1}).
     *
     * @param array input values
     * @return the sample variance, or {@code NaN} when fewer than two values are given
     */
    public static double getVariance(WrappedArray<Double> array) {
        return getVariance(toList(array));
    }

    /**
     * Returns the sample variance (divisor {@code n - 1}).
     *
     * @param array input values
     * @return the sample variance, or {@code NaN} when fewer than two values are given
     */
    public static double getVariance(List<Double> array) {
        int size = array.size();
        // Sample variance is undefined for n < 2; without this guard an empty input
        // would yield 0 / -1 == -0.0 instead of NaN.
        if (size < 2) {
            return Double.NaN;
        }
        double mean = getMean(array);
        double sumSq = 0;
        for (double value : array) {
            double diff = value - mean;
            sumSq += diff * diff;
        }
        return sumSq / (size - 1);
    }

    /**
     * Returns the sample standard deviation, i.e. the square root of {@link #getVariance}.
     *
     * @param array input values
     * @return the sample standard deviation, or {@code NaN} when fewer than two values are given
     */
    public static double getStdDeviation(WrappedArray<Double> array) {
        return getStdDeviation(toList(array));
    }

    /**
     * Returns the sample standard deviation, i.e. the square root of {@link #getVariance}.
     *
     * @param array input values
     * @return the sample standard deviation, or {@code NaN} when fewer than two values are given
     */
    public static double getStdDeviation(List<Double> array) {
        return Math.sqrt(getVariance(array));
    }

    /**
     * Returns the sample covariance (divisor {@code n - 1}) of two equally sized series.
     *
     * @param array1 first series
     * @param array2 second series, paired element-wise with {@code array1}
     * @return the sample covariance, or {@code NaN} when fewer than two pairs are given
     * @throws IllegalArgumentException if the two series differ in length
     */
    public static double getCovariance(WrappedArray<Double> array1, WrappedArray<Double> array2) {
        return getCovariance(toList(array1), toList(array2));
    }

    /**
     * Returns the sample covariance (divisor {@code n - 1}) of two equally sized series.
     *
     * @param array1 first series
     * @param array2 second series, paired element-wise with {@code array1}
     * @return the sample covariance, or {@code NaN} when fewer than two pairs are given
     * @throws IllegalArgumentException if the two series differ in length
     */
    public static double getCovariance(List<Double> array1, List<Double> array2) {
        requireSameSize(array1, array2);
        return getCovariance(array1, array2, getMean(array1), getMean(array2));
    }

    /**
     * Returns the sample covariance (divisor {@code n - 1}) using caller-supplied means, which
     * avoids recomputing them when they are already known.
     *
     * @param array1 first series
     * @param array2 second series, paired element-wise with {@code array1}
     * @param mean1  mean of {@code array1}
     * @param mean2  mean of {@code array2}
     * @return the sample covariance, or {@code NaN} when fewer than two pairs are given
     * @throws IllegalArgumentException if the two series differ in length
     */
    public static double getCovariance(List<Double> array1, List<Double> array2, Double mean1, Double mean2) {
        requireSameSize(array1, array2);
        int size = array1.size();
        if (size < 2) {
            return Double.NaN;
        }
        double m1 = mean1;
        double m2 = mean2;
        double cov = 0;
        for (int i = 0; i < size; i++) {
            cov += (array1.get(i) - m1) * (array2.get(i) - m2);
        }
        return cov / (size - 1);
    }

    /**
     * Compounds per-period rates over an interval.
     *
     * @param array per-period rates {@code r_i}
     * @return the gross growth factor {@code prod(1 + r_i)}; this is one plus the compounded
     *         return, and {@code 1.0} for an empty input
     */
    public static double getIntervalYield(WrappedArray<Double> array) {
        return getIntervalYield(toList(array));
    }

    /**
     * Compounds per-period rates over an interval.
     *
     * @param array per-period rates {@code r_i}
     * @return the gross growth factor {@code prod(1 + r_i)}; this is one plus the compounded
     *         return, and {@code 1.0} for an empty input
     */
    public static double getIntervalYield(List<Double> array) {
        double growth = 1;
        for (double rate : array) {
            growth *= (1 + rate);
        }
        return growth;
    }

    /**
     * Annualizes a value by multiplying it by the square root of the number of periods per year
     * (365 for "D", 52 for "W", 12 for "M"). This is the square-root-of-time scaling commonly
     * applied to volatility.
     *
     * @param value value to annualize
     * @param freq  period code: "D", "W" or "M"; any other non-null code returns {@code value}
     *              unchanged
     * @return the annualized value
     * @throws NullPointerException if {@code freq} or {@code value} is {@code null}
     */
    public static double getYeared1(Double value, String freq) {
        double periods = periodsPerYear(freq, 365);
        return periods == 0 ? value : value * Math.sqrt(periods);
    }

    /**
     * Annualizes an interval growth factor with a 365-day year:
     * {@code value ^ (periodsPerYear / size)}. {@code value} is presumably a gross factor such as
     * the result of {@link #getIntervalYield} — confirm against callers.
     *
     * @param value growth factor over {@code size} periods
     * @param freq  period code: "D" (365/yr), "W" (52/yr) or "M" (12/yr); any other non-null code
     *              returns {@code value} unchanged
     * @param size  number of periods covered by {@code value}
     * @return the annualized factor
     * @throws NullPointerException if {@code freq} or {@code value} is {@code null}
     */
    public static double getYeared2(Double value, String freq, int size) {
        double periods = periodsPerYear(freq, 365);
        return periods == 0 ? value : Math.pow(value, periods / size);
    }

    /**
     * Annualizes an interval growth factor with a 250-trading-day year:
     * {@code value ^ (periodsPerYear / size)}. {@code value} is presumably a gross factor such as
     * the result of {@link #getIntervalYield} — confirm against callers.
     *
     * @param value growth factor over {@code size} periods
     * @param freq  period code: "D" (250/yr), "W" (52/yr) or "M" (12/yr); any other non-null code
     *              returns {@code value} unchanged
     * @param size  number of periods covered by {@code value}
     * @return the annualized factor
     * @throws NullPointerException if {@code freq} or {@code value} is {@code null}
     */
    public static double getYeared3(Double value, String freq, int size) {
        double periods = periodsPerYear(freq, 250);
        return periods == 0 ? value : Math.pow(value, periods / size);
    }

    /**
     * Computes the virtual unit net value series: with the rates ordered by key, element
     * {@code i} is {@code prod_{j <= i}(1 + r_j)}.
     *
     * @param array1 per-period rates
     * @param array2 sort keys paired element-wise with {@code array1} (presumably date strings);
     *               ordered by {@code String} natural order, and for duplicate keys the last rate
     *               wins
     * @return one net value per distinct key, in key order
     * @throws IllegalArgumentException if the two arrays differ in length
     */
    public static double[] getVirtualUnitNet(WrappedArray<Double> array1, WrappedArray<String> array2) {
        return getVirtualUnitNet(toList(array1), toList(array2));
    }

    /**
     * Computes the virtual unit net value series: with the rates ordered by key, element
     * {@code i} is {@code prod_{j <= i}(1 + r_j)}.
     *
     * @param array1 per-period rates
     * @param array2 sort keys paired element-wise with {@code array1} (presumably date strings);
     *               ordered by {@code String} natural order, and for duplicate keys the last rate
     *               wins
     * @return one net value per distinct key, in key order
     * @throws IllegalArgumentException if the two lists differ in length
     */
    public static double[] getVirtualUnitNet(List<Double> array1, List<String> array2) {
        TreeMap<String, Double> rates = toSortedMap(array1, array2);
        // Sized by distinct keys: sizing by input length left trailing 0.0 entries
        // whenever keys repeated.
        double[] virNet = new double[rates.size()];
        double net = 1;
        int i = 0;
        for (double rate : rates.values()) {
            net *= (1 + rate);
            virNet[i++] = net;
        }
        return virNet;
    }

    /**
     * Computes the virtual unit net value for each key: with rates ordered by key, the value
     * for a key is {@code prod(1 + r_j)} over that key and all keys before it.
     *
     * @param array1 per-period rates
     * @param array2 sort keys paired element-wise with {@code array1} (presumably date strings);
     *               for duplicate keys the last rate wins
     * @return a new key-sorted map from key to cumulative net value
     * @throws IllegalArgumentException if the two arrays differ in length
     */
    public static TreeMap<String, Double> getVirtualUnitNetMap(WrappedArray<Double> array1, WrappedArray<String> array2) {
        return getVirtualUnitNetMap(toList(array1), toList(array2));
    }

    /**
     * Computes the virtual unit net value for each key: with rates ordered by key, the value
     * for a key is {@code prod(1 + r_j)} over that key and all keys before it.
     *
     * @param array1 per-period rates
     * @param array2 sort keys paired element-wise with {@code array1} (presumably date strings);
     *               for duplicate keys the last rate wins
     * @return a new key-sorted map from key to cumulative net value
     * @throws IllegalArgumentException if the two lists differ in length
     */
    public static TreeMap<String, Double> getVirtualUnitNetMap(List<Double> array1, List<String> array2) {
        TreeMap<String, Double> map = toSortedMap(array1, array2);
        double net = 1;
        // Replace each rate with the running product in place.
        for (Map.Entry<String, Double> entry : map.entrySet()) {
            net *= (entry.getValue() + 1);
            entry.setValue(net);
        }
        return map;
    }

    /**
     * Returns the critical value for a given confidence level, i.e. the inverse of the standard
     * normal cumulative distribution function, using Acklam's rational approximation.
     *
     * @param p probability
     * @return the standard normal quantile at {@code p}, or {@code NaN} when {@code p} is not in
     *         the open interval (0, 1) or is {@code NaN}
     */
    public static double getNormsinv(double p) {
        // The formulas below already evaluate to NaN for p <= 0, p >= 1 and NaN input;
        // this guard makes that explicit.
        if (!(p > 0 && p < 1)) {
            return Double.NaN;
        }

        final double[] a = NORMSINV_A;
        final double[] b = NORMSINV_B;
        final double[] c = NORMSINV_C;
        final double[] d = NORMSINV_D;
        double q;
        double r;

        if (p < NORMSINV_P_LOW) {
            // Lower tail.
            q = Math.sqrt(-2 * Math.log(p));
            return (((((c[0] * q + c[1]) * q + c[2]) * q + c[3]) * q + c[4])
                    * q + c[5])
                    / ((((d[0] * q + d[1]) * q + d[2]) * q + d[3]) * q + 1);
        } else if (p > NORMSINV_P_HIGH) {
            // Upper tail: mirror of the lower tail.
            q = Math.sqrt(-2 * Math.log(1 - p));
            return -(((((c[0] * q + c[1]) * q + c[2]) * q + c[3]) * q + c[4])
                    * q + c[5])
                    / ((((d[0] * q + d[1]) * q + d[2]) * q + d[3]) * q + 1);
        } else {
            // Central region.
            q = p - 0.5;
            r = q * q;
            return (((((a[0] * r + a[1]) * r + a[2]) * r + a[3]) * r + a[4])
                    * r + a[5])
                    * q
                    / (((((b[0] * r + b[1]) * r + b[2]) * r + b[3]) * r + b[4])
                    * r + 1);
        }
    }

    /**
     * Computes the linking coefficient for each day (equity variant). With rates ordered by key,
     * the coefficient for a key is {@code prod(1 + r_j)} over the strictly earlier keys, so the
     * first key gets {@code 1.0}.
     *
     * @param arg1 per-period rates
     * @param arg2 sort keys paired element-wise with {@code arg1} (presumably date strings);
     *             for duplicate keys the last rate wins
     * @return a key-sorted map from key to linking coefficient
     * @throws IllegalArgumentException if the two arrays differ in length
     */
    public static TreeMap<String, Double> getLinkCoefWithDate(WrappedArray<Double> arg1, WrappedArray<String> arg2) {
        TreeMap<String, Double> rates = toSortedMap(toList(arg1), toList(arg2));
        TreeMap<String, Double> outMap = new TreeMap<>();
        double coef = 1;
        for (Map.Entry<String, Double> entry : rates.entrySet()) {
            // Record the coefficient *before* compounding in this key's own rate.
            outMap.put(entry.getKey(), coef);
            coef *= (1 + entry.getValue());
        }
        return outMap;
    }

    /**
     * Computes the linking coefficient for each day (bond performance-attribution variant).
     * The first key gets {@code 1.0}; each later coefficient is
     * {@code prev * (1 + r_prev * prev)}, where {@code prev} and {@code r_prev} are the
     * coefficient and rate of the preceding key.
     *
     * @param arg1 per-period rates
     * @param arg2 sort keys paired element-wise with {@code arg1} (presumably date strings);
     *             for duplicate keys the last rate wins
     * @return a key-sorted map from key to linking coefficient
     * @throws IllegalArgumentException if the two arrays differ in length
     */
    public static TreeMap<String, Double> getLinkCoefWithDateByBond(WrappedArray<Double> arg1, WrappedArray<String> arg2) {
        TreeMap<String, Double> rates = toSortedMap(toList(arg1), toList(arg2));
        TreeMap<String, Double> outMap = new TreeMap<>();
        double coef = 1;
        for (Map.Entry<String, Double> entry : rates.entrySet()) {
            outMap.put(entry.getKey(), coef);
            // This key's rate is scaled by the coefficient in effect for it before compounding.
            double scaledRate = entry.getValue() * coef;
            coef *= (1 + scaledRate);
        }
        return outMap;
    }

    /** Copies a Scala {@code WrappedArray} into a new Java {@code List}, preserving order. */
    private static <T> List<T> toList(WrappedArray<T> array) {
        List<T> list = new ArrayList<>(array.size());
        Iterator<T> iter = array.iterator();
        while (iter.hasNext()) {
            list.add(iter.next());
        }
        return list;
    }

    /** Pairs {@code keys.get(i)} with {@code values.get(i)} in a key-sorted map; later duplicates win. */
    private static <V> TreeMap<String, V> toSortedMap(List<V> values, List<String> keys) {
        requireSameSize(values, keys);
        TreeMap<String, V> map = new TreeMap<>();
        for (int i = 0; i < values.size(); i++) {
            map.put(keys.get(i), values.get(i));
        }
        return map;
    }

    /** Throws {@link IllegalArgumentException} unless both lists have the same length. */
    private static void requireSameSize(List<?> first, List<?> second) {
        if (first.size() != second.size()) {
            throw new IllegalArgumentException(
                    "input size mismatch: " + first.size() + " vs " + second.size());
        }
    }

    /**
     * Maps a frequency code to periods per year: "D" to {@code daysPerYear}, "W" to 52, and "M"
     * to 12. Any other code maps to 0, meaning "unknown frequency, leave the value unchanged".
     * A {@code null} code throws {@link NullPointerException}, as a {@code String} switch does.
     */
    private static double periodsPerYear(String freq, double daysPerYear) {
        switch (freq) {
            case "D":
                return daysPerYear;
            case "W":
                return 52;
            case "M":
                return 12;
            default:
                return 0;
        }
    }
}